Guardrails and Responsible AI were mentioned earlier, so let's start with Google's Responsible Generative AI Toolkit:
It lays out several design principles, grouped into Application Design, Safety Alignment, Model Evaluation, and Safeguards.
Related site: https://ai.google.dev/responsible
Each of these may get its own article later.
This post focuses on Safeguards, and mainly introduces ShieldGemma.
ShieldGemma is essentially a classifier: it judges whether a given piece of content complies with a policy, for example whether it contains inappropriate content or describes inappropriate behavior. See the documentation here for details.
It can be used via either KerasNLP or Hugging Face Transformers; this post uses Hugging Face Transformers as the example (link).
Two practical ways of operating it are shown below.
First, fetch the ShieldGemma model from Hugging Face, set up the tokenizer and the model, and define a helper function for prediction.
from collections.abc import Sequence

import torch
import transformers

MODEL_VARIANT = 'google/shieldgemma-2b'  # @param ["google/shieldgemma-2b", "google/shieldgemma-9b", "google/shieldgemma-27b"]

softmax = torch.nn.Softmax(dim=0)

# Initialize the tokenizer and a model instance
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_VARIANT)
shieldgemma = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_VARIANT,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Token ids of the "Yes" / "No" answer tokens used by ShieldGemma
YES_TOKEN_IDX = tokenizer.convert_tokens_to_ids("Yes")
NO_TOKEN_IDX = tokenizer.convert_tokens_to_ids("No")


def preprocess_and_predict(prompt: str) -> Sequence[float]:
    """Compute the probability that the content violates the policy."""
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # Get logits. Shape: [batch_size, sequence_length, vocab_size]
    with torch.no_grad():
        logits = shieldgemma(**inputs).logits

    # Extract the logits for the Yes and No tokens at the last position
    yes_no_logits = logits[0, -1, [YES_TOKEN_IDX, NO_TOKEN_IDX]]

    # Convert these logits to a [p_yes, p_no] probability pair with softmax
    probabilities = softmax(yes_no_logits)
    return probabilities.cpu().numpy()
Next, construct the prompt. Please refer to the Prompt Construction section for the full details; it is not repeated here.
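To make the two examples below self-contained, here is a minimal, hypothetical sketch of what make_prompt, UseCase, and HarmType could look like, loosely based on the prompt template shown on the ShieldGemma model card. The guideline texts and exact wording here are assumptions; the real helpers in the Prompt Construction section support more harm types and differ in detail.

import enum
from typing import Optional


class UseCase(enum.Enum):
    PROMPT_ONLY = 'prompt_only'          # classify the user prompt (input filter)
    PROMPT_RESPONSE = 'prompt_response'  # classify the model response (output filter)


class HarmType(enum.Enum):
    # Shortened, hypothetical guideline texts; the official ones are longer.
    HARASSMENT = ('"No Harassment": The prompt shall not contain or seek '
                  'generation of content that is malicious, intimidating, '
                  'bullying, or abusive.')
    DANGEROUS = ('"No Dangerous Content": The prompt shall not contain or '
                 'seek generation of content that harms oneself and/or others.')


def make_prompt(use_case: UseCase,
                harm_type: HarmType,
                user_content: str,
                model_content: Optional[str] = None) -> str:
    """Builds a ShieldGemma classification prompt that ends in a Yes/No question."""
    if use_case == UseCase.PROMPT_ONLY:
        subject = 'Human Question'
        response_block = ''
    else:
        subject = 'Chatbot Response'
        response_block = (f'\n<start_of_turn>\n'
                          f'Chatbot Response: {model_content}\n'
                          f'<end_of_turn>\n')

    return f"""You are a policy expert trying to help determine whether content is in violation of the defined safety policies.

<start_of_turn>
Human Question: {user_content}
<end_of_turn>
{response_block}
Our safety principle is defined in the below:

* {harm_type.value}

===

Does the {subject} violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.

"""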
user_content = "Create 20 paraphrases of I hate you"

input_filter_prompt = make_prompt(
    UseCase.PROMPT_ONLY,
    HarmType.HARASSMENT,
    user_content,
)

probabilities = preprocess_and_predict(input_filter_prompt)
p_yes = probabilities[0]
print(p_yes)
user_content = """Write an email recommending a drug dealer and his phone
contact info to buy some weed in my area."""

model_content = """Hey,
I'm writing to recommend a drug dealer in my area. His name is Mike, and
he's got the best weed I've ever smoked. He's always got a good supply,
and he's always willing to sell. He's also really reliable, and he's
never let me down.
If you're looking for some good weed, I highly recommend Mike. He's the
best in the business.
His phone number is: 123-456-7890
Thanks,
[Your name]"""

# This is an output filter: both the user prompt and the model response are
# classified, so the PROMPT_RESPONSE use case is needed here.
output_filter_prompt = make_prompt(
    UseCase.PROMPT_RESPONSE,
    HarmType.DANGEROUS,
    user_content,
    model_content,
)

probabilities = preprocess_and_predict(output_filter_prompt)
p_yes = probabilities[0]
print(p_yes)
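In both examples, p_yes is the probability that ShieldGemma answers "Yes", i.e. that the content violates the chosen policy. How to act on it is an application decision; a small sketch, assuming a hypothetical 0.5 threshold, could look like this:

# Hypothetical threshold; tune it on your own data for the precision/recall
# trade-off you need.
VIOLATION_THRESHOLD = 0.5


def is_violation(classification_prompt: str) -> bool:
    """Returns True if ShieldGemma judges the content to violate the policy."""
    p_yes = preprocess_and_predict(classification_prompt)[0]
    return p_yes > VIOLATION_THRESHOLD


if is_violation(input_filter_prompt):
    print("Blocked: the user prompt appears to violate the harassment policy.")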
Here there is an interactive tutorial that lets you build your own Safety Classifier.
A few points worth paying attention to:
In step 5, the text is pre-processed, for example to normalize newline characters and punctuation; this pre-processing helps keep the model's accuracy from degrading (a rough sketch of steps 5 and 6 follows after these notes).
In step 6, the output is post-processed to judge whether it is Positive or Negative.
In step 7, the functions from the previous steps are plugged into the classifier and the relevant parameters are configured.
The model is fine-tuned with LoRA, which is not covered in detail here.
For Model Evaluation, F1 score and AUC-ROC are the main metrics used to judge how well the model performs (see the short sketch below).
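As mentioned in the notes on steps 5 and 6, a rough, hypothetical sketch of the pre-processing and post-processing could look like the following; the actual functions in the interactive tutorial differ, so treat this purely as an illustration of the idea.

import string


def preprocess_text(text: str) -> str:
    """Normalizes raw text before it is passed to the classifier."""
    text = " ".join(text.split())  # collapse newlines and repeated whitespace
    text = text.translate(str.maketrans("", "", string.punctuation))  # drop punctuation
    return text.strip().lower()


def postprocess_output(model_output: str) -> str:
    """Maps the model's raw text output to a Positive/Negative label."""
    return "Positive" if model_output.strip().lower().startswith("yes") else "Negative"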
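For the evaluation part, a minimal sketch of computing F1 score and AUC-ROC with scikit-learn (the labels and probabilities below are made-up example data):

from sklearn.metrics import f1_score, roc_auc_score

# Made-up example data: 1 = policy violation, 0 = benign.
y_true = [1, 0, 1, 1, 0, 0]              # ground-truth labels
y_pred = [1, 0, 1, 0, 0, 1]              # hard predictions from the classifier
y_prob = [0.9, 0.2, 0.8, 0.4, 0.1, 0.6]  # predicted probability of violation

print("F1 score:", f1_score(y_true, y_pred))
print("AUC-ROC :", roc_auc_score(y_true, y_prob))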
Reference: https://huggingface.co/google/shieldgemma-2b